{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import os, re\n",
    "from tensorflow.python.layers.core import Dense"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Upper bound on the number of characters kept from every line.\n",
    "MAX_CHAR_PER_LINE = 20\n",
    "\n",
    "def load_sentences(path):\n",
    "    \"\"\"Read the text file at `path` and return its lines, cleaned and clipped.\n",
    "\n",
    "    The file is decoded as ISO-8859-1, non-ASCII characters are dropped and\n",
    "    the text is lower-cased. Every run of characters other than a-z or\n",
    "    newline is collapsed into a single space, and each resulting line is cut\n",
    "    down to at most MAX_CHAR_PER_LINE characters.\n",
    "\n",
    "    Returns a list with one string per line of the file (a trailing newline\n",
    "    in the file yields a final empty string).\n",
    "    \"\"\"\n",
    "    with open(path, 'r', encoding=\"ISO-8859-1\") as handle:\n",
    "        raw_text = handle.read()\n",
    "    ascii_text = raw_text.encode('ascii', 'ignore').decode('UTF-8').lower()\n",
    "    cleaned_text = re.sub('[^a-z\\n]+', ' ', ascii_text)\n",
    "    return [row[:MAX_CHAR_PER_LINE] for row in cleaned_text.split('\\n')]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_character_vocab(data):\n",
    "    \"\"\"Build the character vocabulary for `data`.\n",
    "\n",
    "    Args:\n",
    "        data: iterable of strings (one per sentence).\n",
    "\n",
    "    Returns:\n",
    "        (int_to_symbol, symbol_to_int): two dicts that are inverses of each\n",
    "        other. Ids 0-3 are the special symbols <PAD>, <UNK>, <GO> and <EOS>;\n",
    "        the characters found in `data` follow in sorted order.\n",
    "    \"\"\"\n",
    "    special_symbols = ['<PAD>', '<UNK>', '<GO>',  '<EOS>']\n",
    "    # Sort the characters: iterating a set of strings can give a different\n",
    "    # order in every interpreter process, which would change the ids between\n",
    "    # kernel restarts and make a saved checkpoint disagree with the vocabulary.\n",
    "    data_symbols = sorted({character for line in data for character in line})\n",
    "    all_symbols = special_symbols + data_symbols\n",
    "    int_to_symbol = dict(enumerate(all_symbols))\n",
    "    symbol_to_int = {symbol: index for index, symbol in int_to_symbol.items()}\n",
    "    return int_to_symbol, symbol_to_int\n",
    "\n",
    "# Source side: cleaned lines plus the id <-> character lookup tables.\n",
    "input_sentences = load_sentences('data/words_input.txt')\n",
    "input_int_to_symbol, input_symbol_to_int = extract_character_vocab(input_sentences)\n",
    "\n",
    "# Target side: same treatment for the output file.\n",
    "output_sentences = load_sentences('data/words_output.txt')\n",
    "output_int_to_symbol, output_symbol_to_int = extract_character_vocab(output_sentences)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: '<PAD>',\n",
       " 1: '<UNK>',\n",
       " 2: '<GO>',\n",
       " 3: '<EOS>',\n",
       " 4: 's',\n",
       " 5: 'n',\n",
       " 6: 'q',\n",
       " 7: 'f',\n",
       " 8: 'v',\n",
       " 9: 'g',\n",
       " 10: 'm',\n",
       " 11: 'w',\n",
       " 12: 'd',\n",
       " 13: 'i',\n",
       " 14: 'o',\n",
       " 15: 'a',\n",
       " 16: 'r',\n",
       " 17: 'y',\n",
       " 18: 'j',\n",
       " 19: 'b',\n",
       " 20: 'c',\n",
       " 21: ' ',\n",
       " 22: 'u',\n",
       " 23: 'p',\n",
       " 24: 'e',\n",
       " 25: 'k',\n",
       " 26: 'h',\n",
       " 27: 't',\n",
       " 28: 'z',\n",
       " 29: 'l',\n",
       " 30: 'x'}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_int_to_symbol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: '<PAD>',\n",
       " 1: '<UNK>',\n",
       " 2: '<GO>',\n",
       " 3: '<EOS>',\n",
       " 4: 's',\n",
       " 5: 'n',\n",
       " 6: 'q',\n",
       " 7: 'f',\n",
       " 8: 'v',\n",
       " 9: 'g',\n",
       " 10: 'm',\n",
       " 11: 'w',\n",
       " 12: 'd',\n",
       " 13: 'i',\n",
       " 14: 'o',\n",
       " 15: 'a',\n",
       " 16: 'r',\n",
       " 17: 'y',\n",
       " 18: 'j',\n",
       " 19: 'b',\n",
       " 20: 'c',\n",
       " 21: ' ',\n",
       " 22: 'u',\n",
       " 23: 'p',\n",
       " 24: 'e',\n",
       " 25: 'k',\n",
       " 26: 'h',\n",
       " 27: 't',\n",
       " 28: 'z',\n",
       " 29: 'l',\n",
       " 30: 'x'}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output_int_to_symbol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Training schedule\n",
    "NUM_EPOCS = 300\n",
    "BATCH_SIZE = 32\n",
    "LEARNING_RATE = 0.0003\n",
    "\n",
    "# Network size\n",
    "RNN_STATE_DIM = 512\n",
    "RNN_NUM_LAYERS = 2\n",
    "ENCODER_EMBEDDING_DIM = 64\n",
    "DECODER_EMBEDDING_DIM = 64\n",
    "\n",
    "# Vocabulary sizes, special symbols included\n",
    "INPUT_NUM_VOCAB = len(input_symbol_to_int)\n",
    "OUTPUT_NUM_VOCAB = len(output_symbol_to_int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Encoder side: a 2-D int32 matrix of symbol ids and a 1-D int32 vector of\n",
    "# sequence lengths. No dimension is fixed in the placeholder shapes.\n",
    "encoder_input_seq = tf.placeholder(tf.int32, shape=[None, None], name='encoder_input_seq')\n",
    "encoder_seq_len = tf.placeholder(tf.int32, shape=(None,), name='encoder_seq_len')\n",
    "\n",
    "# Decoder side: same layout for the target ids and their lengths.\n",
    "decoder_output_seq = tf.placeholder(tf.int32, shape=[None, None], name='decoder_output_seq')\n",
    "decoder_seq_len = tf.placeholder(tf.int32, shape=(None,), name='decoder_seq_len')\n",
    "\n",
    "# Largest target length among the fed lengths.\n",
    "max_decoder_seq_len = tf.reduce_max(decoder_seq_len, name='max_decoder_seq_len')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def make_cell(state_dim):\n",
    "    \"\"\"Return one LSTMCell with `state_dim` units.\n",
    "\n",
    "    The cell is given a uniform random initializer over [-0.1, 0.1].\n",
    "    \"\"\"\n",
    "    return tf.contrib.rnn.LSTMCell(\n",
    "        state_dim,\n",
    "        initializer=tf.random_uniform_initializer(-0.1, 0.1)\n",
    "    )\n",
    " \n",
    "def make_multi_cell(state_dim, num_layers):\n",
    "    \"\"\"Return a MultiRNNCell stacking `num_layers` cells from make_cell(state_dim).\"\"\"\n",
    "    layers = []\n",
    "    for _ in range(num_layers):\n",
    "        layers.append(make_cell(state_dim))\n",
    "    return tf.contrib.rnn.MultiRNNCell(layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder embedding: each input id becomes an ENCODER_EMBEDDING_DIM-wide vector.\n",
    "encoder_input_embedded = tf.contrib.layers.embed_sequence(\n",
    "    ids=encoder_input_seq,\n",
    "    vocab_size=INPUT_NUM_VOCAB,\n",
    "    embed_dim=ENCODER_EMBEDDING_DIM\n",
    ")\n",
    "\n",
    "# Encoder RNN: run the stacked LSTM over the embedded input.\n",
    "encoder_multi_cell = make_multi_cell(RNN_STATE_DIM, RNN_NUM_LAYERS)\n",
    "encoder_output, encoder_state = tf.nn.dynamic_rnn(\n",
    "    cell=encoder_multi_cell,\n",
    "    inputs=encoder_input_embedded,\n",
    "    sequence_length=encoder_seq_len,\n",
    "    dtype=tf.float32\n",
    ")\n",
    "\n",
    "# Only the final state is used afterwards; drop the per-step outputs.\n",
    "del encoder_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Decoder inputs for training: the targets shifted right by one step, with\n",
    "# a <GO> id in front of every row (the last target column is dropped).\n",
    "decoder_raw_seq = decoder_output_seq[:, :-1]\n",
    "# Size the <GO> column from the fed targets at run time rather than from\n",
    "# BATCH_SIZE, so a batch of any size can be pushed through the graph.\n",
    "decoder_batch_size = tf.shape(decoder_output_seq)[0]\n",
    "go_prefixes = tf.fill(tf.stack([decoder_batch_size, 1]), output_symbol_to_int['<GO>'])\n",
    "decoder_input_seq = tf.concat([go_prefixes, decoder_raw_seq], 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trainable embedding table for the output vocabulary, plus the embedded\n",
    "# decoder inputs looked up from it.\n",
    "decoder_embedding = tf.Variable(\n",
    "    tf.random_uniform([OUTPUT_NUM_VOCAB, DECODER_EMBEDDING_DIM])\n",
    ")\n",
    "decoder_input_embedded = tf.nn.embedding_lookup(\n",
    "    decoder_embedding,\n",
    "    decoder_input_seq\n",
    ")\n",
    "\n",
    "# Decoder RNN: same depth and width as the encoder stack.\n",
    "decoder_multi_cell = make_multi_cell(RNN_STATE_DIM, RNN_NUM_LAYERS)\n",
    "\n",
    "# Projection from the RNN output to one score per output symbol.\n",
    "output_layer_kernel_initializer = tf.truncated_normal_initializer(mean=0.0, stddev=0.1)\n",
    "output_layer = Dense(\n",
    "    units=OUTPUT_NUM_VOCAB,\n",
    "    kernel_initializer=output_layer_kernel_initializer\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tf.variable_scope(\"decode\"):\n",
    "\n",
    "    # Feeds the embedded ground-truth inputs to the decoder at every step.\n",
    "    training_helper = tf.contrib.seq2seq.TrainingHelper(\n",
    "        inputs=decoder_input_embedded,\n",
    "        sequence_length=decoder_seq_len,\n",
    "        time_major=False\n",
    "    )\n",
    "\n",
    "    # The decoder starts from the encoder's final state and projects every\n",
    "    # step through output_layer.\n",
    "    training_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "        cell=decoder_multi_cell,\n",
    "        helper=training_helper,\n",
    "        initial_state=encoder_state,\n",
    "        output_layer=output_layer\n",
    "    )\n",
    "\n",
    "    # Decoding is capped at the longest target length in the batch.\n",
    "    training_decoder_output_seq, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "        decoder=training_decoder,\n",
    "        impute_finished=True,\n",
    "        maximum_iterations=max_decoder_seq_len\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "with tf.variable_scope(\"decode\", reuse=True):\n",
    "    # One <GO> id per example. The count is read from the encoder input at\n",
    "    # run time instead of being fixed to BATCH_SIZE, so inference accepts any\n",
    "    # number of sentences (including a single one).\n",
    "    start_tokens = tf.tile(\n",
    "        tf.constant([output_symbol_to_int['<GO>']],\n",
    "                    dtype=tf.int32),\n",
    "        tf.shape(encoder_input_seq)[:1],\n",
    "        name='start_tokens')\n",
    "\n",
    "    # Helper for the inference process: greedy decoding that looks up the\n",
    "    # next input in decoder_embedding and uses <EOS> as the end token.\n",
    "    inference_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(\n",
    "        embedding=decoder_embedding,\n",
    "        start_tokens=start_tokens,\n",
    "        end_token=output_symbol_to_int['<EOS>']\n",
    "    )\n",
    "\n",
    "    # Basic decoder: same cell, initial state and output layer as in training.\n",
    "    inference_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "        decoder_multi_cell,\n",
    "        inference_helper,\n",
    "        encoder_state,\n",
    "        output_layer\n",
    "    )\n",
    "\n",
    "    # Perform dynamic decoding using the decoder; the number of steps is\n",
    "    # capped by max_decoder_seq_len, so decoder_seq_len must be fed too.\n",
    "    inference_decoder_output_seq, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "        inference_decoder,\n",
    "        impute_finished=True,\n",
    "        maximum_iterations=max_decoder_seq_len\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Give the decoder results stable tensor names. Note that inference_logits\n",
    "# holds the sampled symbol ids (sample_id), not logits.\n",
    "training_logits = tf.identity(training_decoder_output_seq.rnn_output, name='logits')\n",
    "inference_logits = tf.identity(inference_decoder_output_seq.sample_id, name='predictions')\n",
    "\n",
    "# Weights for sequence_loss, built from the target lengths and\n",
    "# max_decoder_seq_len columns wide.\n",
    "masks = tf.sequence_mask(\n",
    "    decoder_seq_len,\n",
    "    max_decoder_seq_len,\n",
    "    dtype=tf.float32,\n",
    "    name='masks'\n",
    ")\n",
    "\n",
    "# Training decoding is capped at max_decoder_seq_len steps, but the fed\n",
    "# target matrix can be wider (it is padded to a fixed width). Clip the\n",
    "# targets to the same number of steps so logits, targets and weights agree\n",
    "# in shape; the removed columns lie beyond every sequence's length.\n",
    "decoder_targets = decoder_output_seq[:, :max_decoder_seq_len]\n",
    "\n",
    "cost = tf.contrib.seq2seq.sequence_loss(\n",
    "    training_logits,\n",
    "    decoder_targets,\n",
    "    masks\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = tf.train.AdamOptimizer(LEARNING_RATE)\n",
    "\n",
    "# Clip every available gradient element-wise to [-5, 5] before applying it;\n",
    "# variables without a gradient are skipped.\n",
    "gradients = optimizer.compute_gradients(cost)\n",
    "capped_gradients = []\n",
    "for grad, var in gradients:\n",
    "    if grad is None:\n",
    "        continue\n",
    "    capped_gradients.append((tf.clip_by_value(grad, -5., 5.), var))\n",
    "train_op = optimizer.apply_gradients(capped_gradients)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def pad(xs, size, pad):\n",
    "    \"\"\"Return a new list: `xs` followed by copies of `pad` up to `size` items.\n",
    "\n",
    "    A list that already has `size` or more items comes back as an unchanged\n",
    "    copy; `xs` itself is never modified.\n",
    "    \"\"\"\n",
    "    missing = size - len(xs)\n",
    "    filler = [pad] * missing\n",
    "    return xs + filler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Encode every sentence as a list of vocabulary ids; characters missing from\n",
    "# the vocabulary map to <UNK>. Every target additionally ends with <EOS>.\n",
    "input_seq = [\n",
    "    [input_symbol_to_int.get(symbol, input_symbol_to_int['<UNK>'])\n",
    "        for symbol in line]\n",
    "    for line in input_sentences\n",
    "]\n",
    "\n",
    "output_seq = [\n",
    "    [output_symbol_to_int.get(symbol, output_symbol_to_int['<UNK>'])\n",
    "        for symbol in line] + [output_symbol_to_int['<EOS>']]\n",
    "    for line in output_sentences\n",
    "]\n",
    "\n",
    "# Print the cost whenever the batch index is a multiple of this value.\n",
    "LOG_EVERY_N_BATCHES = 629\n",
    "num_batches = len(input_sentences) // BATCH_SIZE\n",
    "\n",
    "sess = tf.InteractiveSession()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "saver = tf.train.Saver()\n",
    "\n",
    "try:\n",
    "    for epoch in range(NUM_EPOCS + 1):\n",
    "        for batch_idx in range(num_batches):\n",
    "            # Batches are consecutive, non-overlapping slices of the data.\n",
    "            batch_start = batch_idx * BATCH_SIZE\n",
    "            batch_end = batch_start + BATCH_SIZE\n",
    "            input_ids = input_seq[batch_start:batch_end]\n",
    "            output_ids = output_seq[batch_start:batch_end]\n",
    "\n",
    "            input_lengths = [len(ids) for ids in input_ids]\n",
    "            # Target lengths include the trailing <EOS>, so the model is\n",
    "            # trained to emit it.\n",
    "            output_lengths = [len(ids) for ids in output_ids]\n",
    "\n",
    "            input_batch = [\n",
    "                pad(ids, MAX_CHAR_PER_LINE, input_symbol_to_int['<PAD>'])\n",
    "                for ids in input_ids\n",
    "            ]\n",
    "            # Pad targets only up to the longest target of this batch: the\n",
    "            # fed width then equals max(decoder_seq_len), the number of steps\n",
    "            # the training decoder is allowed to produce.\n",
    "            target_width = max(output_lengths)\n",
    "            output_batch = [\n",
    "                pad(ids, target_width, output_symbol_to_int['<PAD>'])\n",
    "                for ids in output_ids\n",
    "            ]\n",
    "\n",
    "            _, cost_val = sess.run(\n",
    "                [train_op, cost],\n",
    "                feed_dict={\n",
    "                    encoder_input_seq: input_batch,\n",
    "                    encoder_seq_len: input_lengths,\n",
    "                    decoder_output_seq: output_batch,\n",
    "                    decoder_seq_len: output_lengths\n",
    "                }\n",
    "            )\n",
    "\n",
    "            if batch_idx % LOG_EVERY_N_BATCHES == 0:\n",
    "                print('Epcoh {}. Batch {}/{}. Cost {}'.format(epoch, batch_idx, num_batches, cost_val))\n",
    "\n",
    "        # Overwrite the checkpoint after every epoch.\n",
    "        saver.save(sess, 'model.ckpt')\n",
    "finally:\n",
    "    # Release the session even if training is interrupted or fails.\n",
    "    sess.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from model.ckpt\n",
      "indeed just one of that r\n"
     ]
    }
   ],
   "source": [
    "sess = tf.InteractiveSession()\n",
    "saver.restore(sess, 'model.ckpt')\n",
    "\n",
    "example_input_sent = \"do you want to play games\"\n",
    "# Characters outside the input vocabulary map to <UNK> instead of raising\n",
    "# KeyError.\n",
    "example_input_symb = [\n",
    "    input_symbol_to_int.get(symbol, input_symbol_to_int['<UNK>'])\n",
    "    for symbol in example_input_sent\n",
    "]\n",
    "# The sentence is repeated BATCH_SIZE times; only the first row of the result\n",
    "# is read below.\n",
    "example_input_batch = [pad(example_input_symb, MAX_CHAR_PER_LINE, input_symbol_to_int['<PAD>'])] * BATCH_SIZE\n",
    "example_input_lengths = [len(example_input_sent)] * BATCH_SIZE\n",
    "\n",
    "# decoder_seq_len is fed because its maximum caps the number of decoding\n",
    "# steps; the input length is reused for it here.\n",
    "output_ints = sess.run(inference_logits, feed_dict={\n",
    "    encoder_input_seq: example_input_batch,\n",
    "    encoder_seq_len: example_input_lengths,\n",
    "    decoder_seq_len: example_input_lengths\n",
    "})[0]\n",
    "\n",
    "# Keep the prediction up to the first <EOS> and leave out <PAD> ids, so the\n",
    "# printed text contains characters only.\n",
    "eos_id = output_symbol_to_int['<EOS>']\n",
    "pad_id = output_symbol_to_int['<PAD>']\n",
    "output_symbols = []\n",
    "for symbol_id in output_ints:\n",
    "    if symbol_id == eos_id:\n",
    "        break\n",
    "    if symbol_id != pad_id:\n",
    "        output_symbols.append(output_int_to_symbol[symbol_id])\n",
    "\n",
    "output_str = ''.join(output_symbols)\n",
    "print(output_str)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
